package com.basic

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.immutable.Nil

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    // 在sql中, 聚合函数如何使用
    val spark: SparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("UDAFDemo")
      .getOrCreate()
    val df = spark.read.json("E:\\ZJJ_SparkSQL\\demo01\\src\\main\\resources\\users.json")
    df.createOrReplaceTempView("user")
    // 注册聚合函数
    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select myAvg(age) from user").show
    spark.close()
  }
}

/**
 * 求平均值
 */
class MyAvg extends UserDefinedAggregateFunction {
  // 输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // 缓冲区的类型
  // 求平均值需要两个值运算,一个是年龄的和,另外一个是多少个年龄参与运算,
  // 所以这就是求平均值了.
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

  // 最终聚合解结果的类型
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // avg初始化是一个值,个数初始化也得是一个值.

    // 在缓冲集合中初始化和
    buffer(0) = 0D // 等价于 buffer.update(0, 0D)
    buffer(1) = 0L //带个L,不然就存成int类型了.
  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(age: Double) =>
        buffer(0) = buffer.getDouble(0) + age //年龄进行相加
        buffer(1) = buffer.getLong(1) + 1L //个数累加碰到一个就加1
      case _ =>
    }

    /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
    if (!input.isNullAt(0)) { // 考虑到传字段可能是null
        val v = input.getAs[Double](0) // getDouble(0)
        buffer(0) = buffer.getDouble(0) + v
        buffer(1) = buffer.getLong(1) + 1L
    }*/
  }

  // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer2 match {
      case Row(sum: Double, count: Long) =>
        // 缓冲区和要集合
        buffer1(0) = buffer1.getDouble(0) + sum
        //个数也要聚合
        buffer1(1) = buffer1.getLong(1) + count

      case _ =>
    }


    // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
    /*buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)*/
  }

  // 返回最终的输出值
  // 就是累加的总数除以个数,就是平均值了.
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}

/**
 * 求和的函数
 */
class MySum extends UserDefinedAggregateFunction {
  // 输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // 缓冲区的类型,求和的时候需要缓冲,计算聚合的是一定要有缓冲
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)

  // 最终聚合解结果的类型
  //因为你是聚合,聚合的结果只有一个
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  //用的时候几乎永远都是true
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  //初始化的就是在缓冲里面去初始化一个值,用来计算聚合的值的.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 在缓冲集合中初始化和
    buffer(0) = 0D // 等价于 buffer.update(0, 0D)
  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 需要先进行非空判断.
    // 如果不为空的话就取出来进行计算.

    input match {
      case Row(age: Double) =>
        // 获取0位置的
        buffer(0) = buffer.getDouble(0) + age

      case _ =>
    }

    /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
    if (!input.isNullAt(0)) { // 考虑到传字段可能是null
        val v = input.getAs[Double](0) // getDouble(0)
        buffer(0) = buffer.getDouble(0) + v
    }*/
  }

  // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
    //这里不需要判断非空,因为缓冲区初始化(initialize方法)是一定有值的.
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)

  }

  // 集合函数返回最终的输出值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}